import torch
import torch.nn.functional as F
from GAT2 import GAT
import env
# Smoke test: build the environment, run one forward/backward pass of the GAT
# model on the initial environment state, and print the model output.

# Bind the instance to its own name instead of shadowing the imported `env`
# module, so the module stays reachable afterwards.
# NOTE(review): the meaning of Env1's argument (10) is not visible here — confirm.
environment = env.Env1(10)

# reset() returns a 4-tuple; only link_state and edge_index are used below.
link_state, edge_index, _demand, _all_paths_code = environment.reset()

# NOTE(review): positional GAT hyperparameters are defined in GAT2 and not
# visible here; presumably the first (21) must match the feature width of
# link_state — confirm against GAT2.GAT.
model = GAT(21, 10, 50, 10, 3)

# Features as float32 (the dtype torch.Tensor(...) produced before), without
# the legacy constructor's "int means size" pitfall.
x = torch.as_tensor(link_state, dtype=torch.float32)
# Edge indices as int64: torch.gather/scatter expect LongTensor indices and
# graph-style edge_index tensors are conventionally torch.long; int32 was fragile.
edge_index_t = torch.as_tensor(edge_index, dtype=torch.long)
y = model(x, edge_index_t)

# Build the regression target with y's own shape, dtype and device. The old
# shape-[1] target matched only through broadcasting (which mse_loss warns
# about); the resulting loss value is unchanged.
target = torch.full_like(y, 2.0)
loss = F.mse_loss(y, target)
loss.backward()
print(y)